# -*- coding: utf-8 -*- #
"""
Time                2023/4/27 16:47
Author:             mingfeng (SunnyQjm)
Email               mfeng@linux.alibaba.com
File                xreadis_hash_table.py
Description:
"""
import logging
import os
from pathlib import Path
from typing import Optional, Union, Awaitable, List

from redis import Redis
from redis.client import Pipeline


class XRedisHashTable:
    """
    An extended Redis hash table implemented with Lua scripts.

    Features:
    1. Supports setting an expire time for each individual field.

    Every operation passes two keys to its Lua script: the hash itself
    (``table_name``) and a companion key built by ``_get_expire_table``.
    That key presumably stores the per-field expire times; confirm this
    against the ``*.lua`` scripts.

    Each script is read from ``<lua_script_base_dir>/<name>.lua``, loaded
    into Redis with SCRIPT LOAD on first use, and invoked via EVALSHA.
    """

    def __init__(self, redis_cli: Redis,
                 lua_script_base_dir: Optional[Union[str, os.PathLike]] = None):
        """
        :param redis_cli: Redis client used to load and execute the scripts.
        :param lua_script_base_dir: Directory containing the ``.lua`` scripts.
            Defaults to the directory of this source file.
        """
        if lua_script_base_dir is None:
            lua_script_base_dir = Path(__file__).resolve().parent
        self._redis = redis_cli
        self._lua_script_base_dir = lua_script_base_dir
        # Script name -> SHA1 digest returned by SCRIPT LOAD.
        self._script_hashes = {}

    def _load_lua_scripts_from_file(self, name: str) -> str:
        """
        Read ``<name>.lua`` from the script directory and SCRIPT LOAD it.

        :return: The SHA1 digest Redis reports for the loaded script.
        """
        lua_script_file = Path(self._lua_script_base_dir) / f"{name}.lua"
        # Explicit encoding: do not depend on the platform's locale default.
        script_value = lua_script_file.read_text(encoding="utf-8")
        return self._redis.script_load(script_value)

    def _get_or_cache_script_sha(self, name: str) -> str:
        """
        Return the SHA of script ``name``, loading it into Redis if needed.

        The cached SHA is verified with SCRIPT EXISTS on every call. A script
        that has disappeared from the server's script cache (e.g. after
        SCRIPT FLUSH) is therefore reloaded transparently, at the cost of one
        extra round trip per operation.
        """
        target_hash = self._script_hashes.get(name)
        if not target_hash or not self._redis.script_exists(target_hash)[0]:
            target_hash = self._load_lua_scripts_from_file(name)
            self._script_hashes[name] = target_hash
        return target_hash

    def _evalsha(self, name: str, *fields) -> Union[Awaitable[str], str]:
        """Run script ``name`` immediately; ``fields`` is EVALSHA's numkeys, keys, args."""
        return self._redis.evalsha(
            self._get_or_cache_script_sha(name), *fields
        )

    def _eval_by_pl(self, pl: Pipeline, name: str, *fields) -> Pipeline:
        """Queue script ``name`` on pipeline ``pl`` and return the pipeline."""
        # The SHA is resolved (and the script loaded if needed) on the main
        # client, because the pipeline only buffers commands.
        return pl.evalsha(
            self._get_or_cache_script_sha(name), *fields
        )

    @staticmethod
    def _get_expire_table(table_name):
        """Name of the companion key used alongside ``table_name``."""
        # The "{...}" hash tag keeps both keys in the same Redis Cluster slot,
        # which a multi-key script requires.
        return "{" + table_name + "}:xhash_expire_time"

    @staticmethod
    def _is_ok(res) -> bool:
        """
        Return True if a script reply is the "OK" status.

        Both ``str`` and ``bytes`` are accepted, so the check holds whether or
        not the client was created with ``decode_responses=True``.
        """
        return res in ("OK", b"OK")

    def hget(self, table_name: str, field: str) -> \
            Optional[str]:
        """Return the ``x_hget`` script's reply for ``field`` (None when it replies nil)."""
        return self._evalsha("x_hget", 2, table_name,
                             self._get_expire_table(table_name), field)

    def hget_pl(self, pl: Pipeline, table_name: str, field: str) -> Pipeline:
        """Pipelined variant of ``hget``; the value arrives with ``pl.execute()``."""
        return self._eval_by_pl(pl, "x_hget", 2, table_name,
                                self._get_expire_table(table_name), field)

    def hgetall(self, table_name: str) -> dict:
        """
        Return all fields of ``table_name`` as a dict.

        The ``x_hgetall`` script replies with a flat
        ``[field1, value1, field2, value2, ...]`` list, which is paired up here.
        """
        kvs = self._evalsha("x_hgetall", 2, table_name,
                            self._get_expire_table(table_name))
        if not kvs:
            return {}
        return dict(zip(kvs[::2], kvs[1::2]))

    def hset(self, table_name: str, *kvs: str,
             expire: int = -1) -> bool:
        """
        Set fields of ``table_name`` via the ``x_hset`` script.

        :param kvs: Presumably a flat ``field1, value1, field2, value2, ...``
            sequence (HSET style); confirm against x_hset.lua.
        :param expire: Passed to the script ahead of ``kvs``. The default -1
            presumably means "no expiry"; confirm against x_hset.lua.
        :return: True if the script replied "OK". False on any other reply, or
            if the call raised, in which case the error is logged rather than
            propagated.
        """
        try:
            res = self._evalsha("x_hset", 2, table_name,
                                self._get_expire_table(table_name), expire,
                                *kvs)
        except Exception:
            # Deliberate best-effort contract: callers get False, not an
            # exception. Log with traceback instead of printing to stdout.
            logging.getLogger(__name__).exception(
                "x_hset failed for table %r", table_name)
            return False
        return self._is_ok(res)

    def hdrop_table(self, table_name: str) -> bool:
        """Run ``x_hdrop_table`` on ``table_name``; True if it replied "OK"."""
        res = self._evalsha("x_hdrop_table", 2, table_name,
                            self._get_expire_table(table_name))
        return self._is_ok(res)

    def hdel(self, table_name: str, *fields: str) -> bool:
        """Run ``x_hdel`` for ``fields`` of ``table_name``; True if it replied "OK"."""
        res = self._evalsha("x_hdel", 2, table_name,
                            self._get_expire_table(table_name), *fields)
        return self._is_ok(res)

    def hexpire(self, table_name: str, field: str, expire: int = -1) -> bool:
        """
        Run ``x_hexpire`` for ``field`` of ``table_name``.

        ``expire`` is passed after ``field``; its unit and the meaning of -1
        are defined by x_hexpire.lua (not visible here).

        :return: True if the script replied "OK".
        """
        res = self._evalsha("x_hexpire", 2, table_name,
                            self._get_expire_table(table_name), field, expire)
        return self._is_ok(res)

    def hlen(self, table_name: str) -> int:
        """Return the ``x_hlen`` script's reply for ``table_name``."""
        return self._evalsha("x_hlen", 2, table_name,
                             self._get_expire_table(table_name))

    def hkeys(self, table_name: str) -> List[str]:
        """Return the ``x_hkeys`` script's reply for ``table_name``."""
        return self._evalsha("x_hkeys", 2, table_name,
                             self._get_expire_table(table_name))
